{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmarking EHRShot with XGBoost\n",
    "\n",
    "This notebook demonstrates how to benchmark the EHRShot dataset (https://som-shahlab.github.io/ehrshot-website/) using XGBoost for various clinical prediction tasks. EHRShot contains patient records drawn from multiple OMOP tables, together with labels for three groups of prediction tasks:\n",
    "\n",
    "- Operational outcomes (ICU admission, length of stay, readmission)\n",
    "- Lab value predictions (anemia, hyperkalemia, hypoglycemia, etc.)\n",
    "- New diagnosis predictions (hypertension, hyperlipidemia, cancer, etc.)\n",
    "\n",
    "We'll use XGBoost as our baseline model to:\n",
    "\n",
    "1. Process the EHRShot data into a feature matrix\n",
    "2. Train and evaluate models for each prediction task\n",
    "3. Compare performance across different OMOP tables\n",
    "\n",
    "The results establish baseline performance metrics for these clinical prediction tasks. The walkthrough below runs a single configuration (the `guo_icu` task with features from `condition_occurrence` and `procedure_occurrence`); the same steps apply to the other tasks and tables.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding PyHealth to sys.path: /home/zw12/PyHealth\n"
     ]
    }
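   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Assumes the notebook is run from a direct subdirectory of the PyHealth repo,\n",
    "# so the parent of the working directory is the repo root.\n",
    "pyhealth_path = os.path.dirname(os.getcwd())\n",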
    "if pyhealth_path not in sys.path:\n",
    "    print(f\"Adding PyHealth to sys.path: {pyhealth_path}\")\n",
    "    sys.path.insert(0, pyhealth_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import xgboost as xgb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from pyhealth.datasets import EHRShotDataset\n",
    "from pyhealth.tasks import BenchmarkEHRShot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and Prepare the Dataset\n",
    "\n",
    "First, we'll load the EHRShot dataset and prepare it for our benchmarking tasks. Alongside the base tables, each prediction task has its own table of labeled patients. We'll load all of them:\n",
    "\n",
    "- Base tables: `ehrshot` and `splits`\n",
    "- Operational outcomes: `guo_icu`, `guo_los`, `guo_readmission`\n",
    "- Lab value predictions: `lab_anemia`, `lab_hyperkalemia`, `lab_hypoglycemia`, etc.\n",
    "- New diagnosis predictions: `new_acutemi`, `new_celiac`, `new_hyperlipidemia`, etc.\n",
    "\n",
    "After loading, we'll examine the dataset statistics to understand its size and composition.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No config path provided, using default config\n",
      "Initializing ehrshot dataset from /shared/eng/EHRSHOT_ASSETS (dev mode: False)\n",
      "Scanning table: ehrshot from /shared/eng/EHRSHOT_ASSETS/data/ehrshot.csv\n",
      "Scanning table: splits from /shared/eng/EHRSHOT_ASSETS/splits/person_id_map.csv\n",
      "Scanning table: guo_icu from /shared/eng/EHRSHOT_ASSETS/benchmark/guo_icu/labeled_patients.csv\n",
      "Scanning table: guo_los from /shared/eng/EHRSHOT_ASSETS/benchmark/guo_los/labeled_patients.csv\n",
      "Scanning table: guo_readmission from /shared/eng/EHRSHOT_ASSETS/benchmark/guo_readmission/labeled_patients.csv\n",
      "Scanning table: lab_anemia from /shared/eng/EHRSHOT_ASSETS/benchmark/lab_anemia/labeled_patients.csv\n",
      "Scanning table: lab_hyperkalemia from /shared/eng/EHRSHOT_ASSETS/benchmark/lab_hyperkalemia/labeled_patients.csv\n",
      "Scanning table: lab_hypoglycemia from /shared/eng/EHRSHOT_ASSETS/benchmark/lab_hypoglycemia/labeled_patients.csv\n",
      "Scanning table: lab_hyponatremia from /shared/eng/EHRSHOT_ASSETS/benchmark/lab_hyponatremia/labeled_patients.csv\n",
      "Scanning table: lab_thrombocytopenia from /shared/eng/EHRSHOT_ASSETS/benchmark/lab_thrombocytopenia/labeled_patients.csv\n",
      "Scanning table: new_acutemi from /shared/eng/EHRSHOT_ASSETS/benchmark/new_acutemi/labeled_patients.csv\n",
      "Scanning table: new_celiac from /shared/eng/EHRSHOT_ASSETS/benchmark/new_celiac/labeled_patients.csv\n",
      "Scanning table: new_hyperlipidemia from /shared/eng/EHRSHOT_ASSETS/benchmark/new_hyperlipidemia/labeled_patients.csv\n",
      "Scanning table: new_hypertension from /shared/eng/EHRSHOT_ASSETS/benchmark/new_hypertension/labeled_patients.csv\n",
      "Scanning table: new_lupus from /shared/eng/EHRSHOT_ASSETS/benchmark/new_lupus/labeled_patients.csv\n",
      "Scanning table: new_pancan from /shared/eng/EHRSHOT_ASSETS/benchmark/new_pancan/labeled_patients.csv\n",
      "Collecting global event dataframe...\n",
      "Collected dataframe with shape: (42820755, 37)\n",
      "Dataset: ehrshot\n",
      "Dev mode: False\n",
      "Number of patients: 6739\n",
      "Number of events: 42820755\n"
     ]
    }
   ],
   "source": [
    "# root points to a local copy of the EHRShot assets; adjust it for your environment\n",
    "dataset = EHRShotDataset(\n",
    "    root=\"/shared/eng/EHRSHOT_ASSETS\",\n",
    "    tables=[\n",
    "        \"ehrshot\",\n",
    "        \"splits\",\n",
    "        \"guo_icu\",\n",
    "        \"guo_los\",\n",
    "        \"guo_readmission\",\n",
    "        \"lab_anemia\",\n",
    "        \"lab_hyperkalemia\",\n",
    "        \"lab_hypoglycemia\",\n",
    "        \"lab_hyponatremia\",\n",
    "        \"lab_thrombocytopenia\",\n",
    "        \"new_acutemi\",\n",
    "        \"new_celiac\",\n",
    "        \"new_hyperlipidemia\",\n",
    "        \"new_hypertension\",\n",
    "        \"new_lupus\",\n",
    "        \"new_pancan\",\n",
    "    ],\n",
    ")\n",
    "dataset.stats()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up Benchmark Tasks\n",
    "\n",
    "Next, we'll define the benchmark tasks we want to evaluate. EHRShot provides labels for the following clinical prediction tasks:\n",
    "\n",
    "- Operational outcomes: ICU admission, length of stay, and readmission prediction\n",
    "- Lab value predictions: anemia, hyperkalemia, hypoglycemia, hyponatremia, and thrombocytopenia\n",
    "- New diagnosis predictions: acute MI, celiac disease, hyperlipidemia, hypertension, lupus, and pancreatic cancer\n",
    "\n",
    "For each task, we select a subset of OMOP tables from which to construct the input features. The cells below list all tasks and tables, then pick one task (`guo_icu`) and two tables for this run.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = [\n",
    "    \"guo_icu\",\n",
    "    \"guo_los\",\n",
    "    \"guo_readmission\",\n",
    "    \"lab_anemia\",\n",
    "    \"lab_hyperkalemia\",\n",
    "    \"lab_hypoglycemia\",\n",
    "    \"lab_hyponatremia\",\n",
    "    \"lab_thrombocytopenia\",\n",
    "    \"new_acutemi\",\n",
    "    \"new_celiac\",\n",
    "    \"new_hyperlipidemia\",\n",
    "    \"new_hypertension\",\n",
    "    \"new_lupus\",\n",
    "    \"new_pancan\"\n",
    "]\n",
    "omop_tables = [\n",
    "    \"device_exposure\",\n",
    "    \"person\",\n",
    "    \"visit_detail\",\n",
    "    \"visit_occurrence\",\n",
    "    \"condition_occurrence\",\n",
    "    \"procedure_occurrence\",\n",
    "    \"note\",\n",
    "    \"drug_exposure\",\n",
    "    \"observation\",\n",
    "    \"measurement\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"guo_icu\"\n",
    "omop_table = [\"condition_occurrence\", \"procedure_occurrence\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting task BenchmarkEHRShot/guo_icu for ehrshot base dataset...\n",
      "Generating samples with 8 worker(s)...\n",
      "Generating samples for BenchmarkEHRShot/guo_icu\n",
      "Label label vocab: {'False': 0, 'True': 1}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing samples: 100%|██████████| 6491/6491 [00:01<00:00, 6456.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated 6491 samples for task BenchmarkEHRShot/guo_icu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "task_fn = BenchmarkEHRShot(task=task, omop_tables=omop_table)\n",
    "samples = dataset.set_task(task_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SequenceProcessor(code_vocab_size=15148)\n",
      "BinaryLabelProcessor(label_vocab_size=2)\n"
     ]
    }
   ],
   "source": [
    "print(samples.input_processors[\"feature\"])\n",
    "print(samples.output_processors[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Features and Labels\n",
    "\n",
    "In this step, we prepare the data for XGBoost by transforming the clinical samples into a feature matrix and a label vector.\n",
    "\n",
    "For features, we create multi-hot encoded vectors in which each position represents a unique clinical event code (e.g., a diagnosis or procedure code).\n",
    "The vector length is the number of indices assigned by the feature processor (`num_features`), and a 1 marks the presence of that code in the sample. Presence is recorded once, so repeated occurrences of a code are not counted.\n",
    "\n",
    "For labels, we extract the target value from each sample. For lab-related tasks (names starting with `lab_`), we convert multi-class labels to binary by\n",
    "setting any positive class (>=1) to 1. This step is skipped for `guo_icu`, whose labels are already binary.\n",
    "\n",
    "The data is then split into training, validation, and test sets using the predefined `split` value (`train`, `val`, or `test`) stored with each sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: (2402, 15147), (2402,)\n",
      "Val set: (2052, 15147), (2052,)\n",
      "Test set: (2037, 15147), (2037,)\n"
     ]
    }
   ],
   "source": [
    "# Determine the maximum index to size the multi-hot vectors\n",
    "num_features = samples.input_processors[\"feature\"]._next_index\n",
    "\n",
    "# Prepare feature matrix and labels\n",
    "X, y, splits = [], [], []\n",
    "\n",
    "for sample in samples:\n",
    "    vec = np.zeros(num_features, dtype=int)\n",
    "    vec[sample[\"feature\"].numpy()] = 1  # multi-hot encoding\n",
    "    X.append(vec)\n",
    "    y.append(sample[\"label\"].item())\n",
    "    splits.append(sample[\"split\"])\n",
    "\n",
    "X = np.array(X)\n",
    "y = np.array(y)\n",
    "splits = np.array(splits)\n",
    "\n",
    "if task.startswith(\"lab_\"):\n",
    "    print(\"Converting multi-class to binary\")\n",
    "    # convert multiclass to binary\n",
    "    y[y >= 1] = 1\n",
    "\n",
    "X_train = X[splits == \"train\"]\n",
    "y_train = y[splits == \"train\"]\n",
    "X_val = X[splits == \"val\"]\n",
    "y_val = y[splits == \"val\"]\n",
    "X_test = X[splits == \"test\"]\n",
    "y_test = y[splits == \"test\"]\n",
    "print(f\"Train set: {X_train.shape}, {y_train.shape}\")\n",
    "print(f\"Val set: {X_val.shape}, {y_val.shape}\")\n",
    "print(f\"Test set: {X_test.shape}, {y_test.shape}\")"
   ]
  },
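  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the encoding used above, the toy example below applies the same three steps (multi-hot encoding, label binarization, split masking) to a few hand-written samples. The `multi_hot` helper and the `toy_samples` records are hypothetical and exist only for illustration; they are not part of PyHealth or EHRShot.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def multi_hot(code_indices, num_features):\n",
    "    # Return a 0/1 vector with ones at the given code indices.\n",
    "    vec = np.zeros(num_features, dtype=np.uint8)\n",
    "    vec[np.asarray(code_indices, dtype=int)] = 1  # repeated indices still yield 1\n",
    "    return vec\n",
    "\n",
    "\n",
    "# Hypothetical toy samples: code indices, a label, and a split name\n",
    "toy_samples = [\n",
    "    {\"feature\": [1, 3, 3], \"label\": 0, \"split\": \"train\"},\n",
    "    {\"feature\": [0, 2], \"label\": 2, \"split\": \"train\"},\n",
    "    {\"feature\": [4], \"label\": 1, \"split\": \"test\"},\n",
    "]\n",
    "\n",
    "X = np.stack([multi_hot(s[\"feature\"], num_features=5) for s in toy_samples])\n",
    "y = np.array([s[\"label\"] for s in toy_samples])\n",
    "splits = np.array([s[\"split\"] for s in toy_samples])\n",
    "\n",
    "y_binary = (y >= 1).astype(int)  # collapse every positive class to 1\n",
    "train_mask = splits == \"train\"   # boolean mask selecting the training rows\n",
    "\n",
    "print(X)\n",
    "print(y_binary)\n",
    "print(X[train_mask].shape)\n",
    "```\n",
    "\n",
    "The first toy sample repeats index 3, yet its row holds a single 1 at that position, which shows that the encoding records presence rather than counts."
   ]
  },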
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Evaluate XGBoost Model\n",
    "\n",
    "In this step, we train an XGBoost classifier on the prepared feature matrix and labels.\n",
    "\n",
    "We fit the model on the training set and evaluate it on the test set using the area under the ROC curve (AUC). The validation set built above is not used in this baseline; it remains available for early stopping or hyperparameter tuning.\n",
    "\n",
    "The model uses 100 trees, a learning rate of 0.1, and a fixed random seed of 42. Log loss is set as the evaluation metric, which XGBoost reports only for evaluation sets passed to `fit`; since none is passed here, it does not affect training.\n",
    "\n",
    "The resulting test AUC indicates how well the multi-hot clinical features predict the target outcome.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test AUC for task guo_icu with input ['condition_occurrence', 'procedure_occurrence']: 0.7543\n"
     ]
    }
   ],
   "source": [
    "model = xgb.XGBClassifier(\n",
    "    n_estimators=100,\n",
    "    learning_rate=0.1,\n",
    "    eval_metric=\"logloss\",\n",
    "    random_state=42,\n",
    "    n_jobs=-1\n",
    ")\n",
    "model.fit(X_train, y_train, verbose=True)\n",
    "\n",
    "# Test set evaluation\n",
    "y_pred_prob = model.predict_proba(X_test)[:, 1]\n",
    "auc = roc_auc_score(y_test, y_pred_prob)\n",
    "print(f\"Test AUC for task {task} with input {omop_table}: {auc:.4f}\")"
   ]
  },
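  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells above run one task with one set of tables. A comparison across the `tasks` and `omop_tables` lists could be organized as a small sweep, sketched below under stated assumptions: `sweep`, `best_per_task`, and `fake_run` are hypothetical names, and `fake_run` is a stand-in that returns a made-up score so the sketch stays self-contained. In practice it would be replaced by a function that wraps the steps above (set the task, build `X` and `y`, fit XGBoost) and returns the test AUC.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "\n",
    "def sweep(tasks, table_subsets, run_benchmark):\n",
    "    # Evaluate every (task, table subset) pair and collect the scores.\n",
    "    results = {}\n",
    "    for task, tables in product(tasks, table_subsets):\n",
    "        results[(task, tuple(tables))] = run_benchmark(task, list(tables))\n",
    "    return results\n",
    "\n",
    "\n",
    "def best_per_task(results):\n",
    "    # For each task, keep the table subset with the highest score.\n",
    "    best = {}\n",
    "    for (task, tables), score in results.items():\n",
    "        if task not in best or score > best[task][1]:\n",
    "            best[task] = (tables, score)\n",
    "    return best\n",
    "\n",
    "\n",
    "def fake_run(task, tables):\n",
    "    # Hypothetical placeholder: NOT a real AUC, just a deterministic number.\n",
    "    return 0.5 + 0.1 * len(tables)\n",
    "\n",
    "\n",
    "results = sweep(\n",
    "    tasks=[\"guo_icu\", \"guo_los\"],\n",
    "    table_subsets=[[\"measurement\"], [\"condition_occurrence\", \"procedure_occurrence\"]],\n",
    "    run_benchmark=fake_run,\n",
    ")\n",
    "for (task, tables), score in sorted(results.items()):\n",
    "    print(f\"{task:10s} {', '.join(tables):45s} {score:.2f}\")\n",
    "```\n",
    "\n",
    "Keying the results by `(task, tables)` keeps each configuration separate, so `best_per_task(results)` can report the strongest table subset for every task."
   ]
  },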
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch20",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
